/*
 *  Rapfi, a Gomoku/Renju playing engine supporting piskvork protocol.
 *  Copyright (C) 2022  Rapfi developers
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#pragma once

#include "../core/pos.h"
#include "../core/types.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

namespace Tuning {

/// Outcome of a finished game, stored in a single byte. Unless the field that
/// holds it says otherwise, the value is expressed from the point of view of
/// the current side to move.
enum Result : uint8_t {
    RESULT_LOSS    = 0,  // the reference side lost the game
    RESULT_DRAW    = 1,  // the game ended in a draw
    RESULT_WIN     = 2,  // the reference side won the game
    RESULT_UNKNOWN = 3,  // the outcome is not known
};

/// One move in the multi-pv: a candidate move paired with its evaluation.
/// Arrays of PVMove are used as the move data payload of multi-pv entries.
struct PVMove
{
    Pos  move;  // the candidate move
    Eval eval;  // evaluation attached to this move
                // NOTE(review): presumably side-to-move pov like DataEntry::eval -- confirm
};

/// DataEntry struct represents a raw training data entry that generated by running
/// a match between two engines. It contains the complete position (represented as
/// a sequence of move played by each side), move outputed by the current side to move,
/// board size, rule, game outcome and an optional policy target.
struct DataEntry
{
    std::vector<Pos> position;   // move sequence that representing a position
    Pos              move;       // best move output by the engine
    Eval             eval;       // evaluation of the position (side to move pov)
    uint8_t          boardsize;  // size of a square board
    Rule             rule;       // rule of the game entry
    Result           result;     // game result: 0=loss, 1=draw, 2=win (side to move pov)
    enum MoveDataTag : uint8_t {
        NO_MOVE_DATA,        // no move data is available
        POLICY_ARRAY_FLOAT,  // float policy probability of size [boardsize^2+1]
        POLICY_ARRAY_INT16,  // int16-quant policy probability of size [boardsize^2+1]
        MULTIPV_BEGIN,       // an array of EXTRA multi-pv moves and evaluations
        // The total number of multi-pv is given as (moveDataTag - MULTIPV_BEGIN + 2).
    } moveDataTag = NO_MOVE_DATA;  // type tag of the move data
    union {
        void    *moveData = nullptr;  // pointer to the optional move data
        float   *policyF32;           // row-major policy target of size [boardsize^2+1]
        int16_t *policyI16;           // row-major policy target of size [boardsize^2+1]
        PVMove  *multiPvMoves;        // list of multi-pv moves output by the engine
    };

    DataEntry() = default;
    DataEntry(std::vector<Pos> pos,
              uint8_t          boardsize,
              Rule             rule,
              Result           result      = RESULT_UNKNOWN,
              MoveDataTag      moveDataTag = NO_MOVE_DATA,
              Pos              move        = Pos::NONE,
              Eval             eval        = VALUE_NONE,
              void            *moveData    = nullptr)
        : position {std::move(pos)}
        , move(move)
        , eval(eval)
        , boardsize {boardsize}
        , rule {rule}
        , result {result}
        , moveDataTag {moveDataTag}
        , moveData(moveData)
    {}
    ~DataEntry()
    {
        switch (moveDataTag) {
        case MoveDataTag::NO_MOVE_DATA: assert(moveData == nullptr); break;
        case MoveDataTag::POLICY_ARRAY_FLOAT: delete[] policyF32; break;
        case MoveDataTag::POLICY_ARRAY_INT16: delete[] policyI16; break;
        default: delete[] multiPvMoves; break;
        }
        moveData = nullptr;
    }
    DataEntry(const DataEntry &other)
        : position {other.position}
        , move {other.move}
        , eval {other.eval}
        , boardsize {other.boardsize}
        , rule {other.rule}
        , result {other.result}
        , moveDataTag {other.moveDataTag}
    {
        switch (moveDataTag) {
        case MoveDataTag::NO_MOVE_DATA: moveData = nullptr; break;
        case MoveDataTag::POLICY_ARRAY_FLOAT: {
            policyF32 = new float[boardsize * boardsize + 1];
            std::copy(other.policyF32, other.policyF32 + boardsize * boardsize + 1, policyF32);
            break;
        }
        case MoveDataTag::POLICY_ARRAY_INT16: {
            policyI16 = new int16_t[boardsize * boardsize + 1];
            std::copy(other.policyI16, other.policyI16 + boardsize * boardsize + 1, policyI16);
            break;
        }
        default: {
            multiPvMoves = new PVMove[other.numExtraPVs()];
            std::copy(other.multiPvMoves, other.multiPvMoves + other.numExtraPVs(), multiPvMoves);
            break;
        }
        }
    }
    DataEntry(DataEntry &&other) noexcept
        : position {std::move(other.position)}
        , move {other.move}
        , eval {other.eval}
        , boardsize {other.boardsize}
        , rule {other.rule}
        , result {other.result}
        , moveDataTag {other.moveDataTag}
    {
        moveData       = other.moveData;
        other.moveData = nullptr;
    }

    /// Returns the current side to move, considering pass moves.
    Color sideToMove() const
    {
        int numPasses = std::count(position.begin(), position.end(), Pos::PASS);
        return (numPasses + position.size()) % 2 == 0 ? BLACK : WHITE;
    }

    /// The number of extra multi-pv moves.
    int numExtraPVs() const
    {
        return moveDataTag >= MULTIPV_BEGIN ? moveDataTag - MULTIPV_BEGIN + 1 : 0;
    }

    /// Get policy target according to pos. If no actual policy target
    /// is available, all policy is concentrated on the best move pos.
    float policyTarget(Pos pos) const
    {
        switch (moveDataTag) {
        case MoveDataTag::NO_MOVE_DATA: return pos == move ? 1.0f : 0.0f;
        case MoveDataTag::POLICY_ARRAY_FLOAT: {
            if (pos == Pos::PASS)
                return policyF32[boardsize * boardsize];
            else
                return policyF32[pos.y() * boardsize + pos.x()];
        }
        case MoveDataTag::POLICY_ARRAY_INT16: {
            if (pos == Pos::PASS)
                return policyI16[boardsize * boardsize] * (1.0f / 32767);
            else
                return policyI16[pos.y() * boardsize + pos.x()] * (1.0f / 32767);
        }
        default: return pos == move ? 1.0f : 0.0f;
        }
    }
};

/// GameEntry struct represents a complete training game that is generated by running
/// a selfplay match or a match between two engines. Positions of each data entry can
/// be acquired from the move sequence outputed by two engines.
/// Note that policy target is not contained in the GameEntry.
struct GameEntry
{
    struct MoveData
    {
        Pos                    move;  // best move output by the engine
        Eval                   eval;  // evaluation of the position (side to move pov)
        DataEntry::MoveDataTag tag = DataEntry::NO_MOVE_DATA;  // type tag of the move data
        union {
            void    *moveData = nullptr;  // pointer to the optional move data
            float   *policyF32;           // row-major policy target of size [boardsize^2+1]
            int16_t *policyI16;           // row-major policy target of size [boardsize^2+1]
            PVMove  *multiPvMoves;        // list of multi-pv moves output by the engine
        };
    };

    std::vector<Pos>      initPosition;  // initial position of the game
    std::vector<MoveData> moveSequence;  // move data sequence that representing a game
    uint8_t               boardsize;     // size of a square board
    Rule                  rule;          // rule of the game entry
    Result                result;        // game result: 0=loss, 1=draw, 2=win (white pov)

    GameEntry()                           = default;
    GameEntry(const GameEntry &other)     = delete;
    GameEntry(GameEntry &&other) noexcept = default;
    ~GameEntry()
    {
        for (auto &moveData : moveSequence) {
            switch (moveData.tag) {
            case DataEntry::MoveDataTag::NO_MOVE_DATA: assert(moveData.moveData == nullptr); break;
            case DataEntry::MoveDataTag::POLICY_ARRAY_FLOAT: delete[] moveData.policyF32; break;
            case DataEntry::MoveDataTag::POLICY_ARRAY_INT16: delete[] moveData.policyI16; break;
            default: delete[] moveData.multiPvMoves; break;
            }
            moveData.moveData = nullptr;
        }
    }
};

}  // namespace Tuning
